/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "CoAPProtocol.h"
#include <iomanip>

namespace aiengine {

// Destructor: drops this protocol's reference to the anomaly manager.
CoAPProtocol::~CoAPProtocol() {

	anomaly_ = nullptr;
}

// Heuristic classifier used before a flow is bound to this protocol.
// A packet is accepted when it is at least header_size bytes long and either
// port is 5683 (the RFC 7252 default CoAP port). On acceptance the payload
// becomes the current header and the valid counter is bumped; otherwise the
// invalid counter is bumped.
bool CoAPProtocol::check(const Packet &packet) {

	int length = packet.getLength();

	bool accepted = false;
	if (length >= header_size)
		accepted = (packet.getSourcePort() == 5683) or (packet.getDestinationPort() == 5683);

	if (!accepted) {
		++total_invalid_packets_;
		return false;
	}

	setHeader(packet.getPayload());
	++total_valid_packets_;
	return true;
}

// Propagates the dynamic-allocation flag to the info, uri and host caches.
void CoAPProtocol::setDynamicAllocatedMemory(bool value) {

	// The caches hold different element types, so a generic lambda is used
	// to apply the same call to each of them.
	auto apply = [value] (auto &cache) { cache->setDynamicAllocatedMemory(value); };

	apply(info_cache_);
	apply(uri_cache_);
	apply(host_cache_);
}

// Reports the dynamic-allocation flag; the info cache is used as the
// representative value for all caches of this protocol.
bool CoAPProtocol::isDynamicAllocatedMemory() const {

	const bool dynamic = info_cache_->isDynamicAllocatedMemory();
	return dynamic;
}

// Size of this object plus the memory currently in use by each cache.
uint64_t CoAPProtocol::getCurrentUseMemory() const {

	return static_cast<uint64_t>(sizeof(CoAPProtocol))
		+ info_cache_->getCurrentUseMemory()
		+ uri_cache_->getCurrentUseMemory()
		+ host_cache_->getCurrentUseMemory();
}

// Size of this object plus the memory allocated (in use or not) by each cache.
uint64_t CoAPProtocol::getAllocatedMemory() const {

	return static_cast<uint64_t>(sizeof(CoAPProtocol))
		+ host_cache_->getAllocatedMemory()
		+ uri_cache_->getAllocatedMemory()
		+ info_cache_->getAllocatedMemory();
}

// Allocated cache memory plus the bytes held by the keys of the host/uri maps.
uint64_t CoAPProtocol::getTotalAllocatedMemory() const {

	const uint64_t caches = getAllocatedMemory();
	const uint64_t maps = compute_memory_used_by_maps();

	return caches + maps;
}

// Returns the host name and uri strings of info to their caches and reports
// the total bytes released. The host is released first, then the uri; they
// stay as separate statements so that order is fixed.
int32_t CoAPProtocol::release_coap_info(CoAPInfo *info) {

	int32_t released = releaseStringToCache(host_cache_, info->host_name);
	released += releaseStringToCache(uri_cache_, info->uri);

	return released;
}

// Sums the lengths of the key strings stored in the host and uri maps.
uint64_t CoAPProtocol::compute_memory_used_by_maps() const {

	uint64_t total = 0;

	for (auto const &entry: host_map_)
		total += entry.first.size();

	for (auto const &entry: uri_map_)
		total += entry.first.size();

	return total;
}

// Total number of failed acquisitions across the info, host and uri caches.
uint32_t CoAPProtocol::getTotalCacheMisses() const {

	uint32_t total = info_cache_->getTotalFails();
	total += host_cache_->getTotalFails();
	total += uri_cache_->getTotalFails();

	return total;
}

void CoAPProtocol::releaseCache() {

        if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
                auto ft = fm->getFlowTable();

                std::ostringstream msg;
                msg << "Releasing " << name() << " cache";

                infoMessage(msg.str());

                uint64_t total_cache_bytes_released = compute_memory_used_by_maps();
                uint64_t total_bytes_released_by_flows = 0;
                uint64_t total_cache_save_bytes = 0;
                uint32_t release_flows = 0;
                uint32_t release_hosts = host_map_.size();
                uint32_t release_uris = uri_map_.size();

                for (auto &flow: ft) {
                        if (SharedPointer<CoAPInfo> info = flow->getCoAPInfo(); info) {
				total_bytes_released_by_flows += sizeof(info);

                                ++release_flows;
                                flow->layer7info.reset();
                                info_cache_->release(info);
                        }
                }
                // Some entries can be still on the maps and needs to be
                // retrieve to their existing caches
                for (auto &entry: host_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
                        releaseStringToCache(host_cache_, entry.second.sc);
		}
                host_map_.clear();

                for (auto &entry: uri_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
                        releaseStringToCache(uri_cache_, entry.second.sc);
		}
                uri_map_.clear();

		data_time_ = boost::posix_time::microsec_clock::local_time();

                msg.str("");
                msg << "Release " << release_hosts << " hosts, " << release_uris << " uris, ";
                msg << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
		infoMessage(msg.str());
        }
}

// Detaches nothing from the flow itself; returns the strings referenced by
// the flow's CoAPInfo (if any) to their caches and the info to the info cache.
void CoAPProtocol::releaseFlowInfo(Flow *flow) {

	auto info = flow->getCoAPInfo();
	if (!info)
		return;

	release_coap_info(info.get());
	info_cache_->release(info);
}

// Processes one CoAP packet of a flow.
// Updates byte/packet counters, flags packets shorter than the fixed header
// as COAP_BOGUS_HEADER anomalies, and for version 1 messages attaches (or
// reuses) a CoAPInfo on the flow, counts the request method and hands the
// option bytes of GET/POST/PUT/DELETE requests to process_common_header().
// Flows whose info is banned are not inspected further.
void CoAPProtocol::processFlow(Flow *flow) {

	CPUCycle cycles(&total_cpu_cycles_);
	setHeader(flow->packet->getPayload());
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++total_packets_;
	++flow->total_packets_l7;

	if (length < header_size) {
		++total_events_;
		if (flow->getPacketAnomaly() == PacketAnomalyType::NONE)
			flow->setPacketAnomaly(PacketAnomalyType::COAP_BOGUS_HEADER);

		anomaly_->incAnomaly(PacketAnomalyType::COAP_BOGUS_HEADER);
		return;
	}

	if (getVersion() != 1)
		return;

	SharedPointer<CoAPInfo> info = flow->getCoAPInfo();
	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

	current_flow_ = flow;

	if (info->isBanned())
		return;

	uint8_t code = getCode();
	const uint8_t *payload = reinterpret_cast<const uint8_t*>(header_);
	// Options start after the fixed header and the token
	int offset = sizeof(coap_header) + getTokenLength();

	bool is_request = true;
	if (code == COAP_CODE_GET) {
		++total_coap_gets_;
	} else if (code == COAP_CODE_POST) {
		++total_coap_posts_;
	} else if (code == COAP_CODE_PUT) {
		++total_coap_puts_;
	} else if (code == COAP_CODE_DELETE) {
		++total_coap_deletes_;
	} else {
		++total_coap_others_;
		is_request = false;
	}

	// TODO anomaly for the size of the getTokenLength()
	// The token length nibble can claim more bytes than the packet carries;
	// only walk the options when at least one option byte is present so
	// the parser never receives a zero or negative length.
	if ((is_request)and(offset < length))
		process_common_header(info.get(), &payload[offset], length - offset);
}

// Decodes one RFC 7252 option nibble (Option Delta or Option Length).
// Nibbles 0-12 are literal values; 13 and 14 mean the real value follows in
// one or two extended bytes at payload[offset], which are consumed (offset
// is advanced); 15 is reserved. Returns -1 for the reserved nibble or when
// the extended bytes would run past length.
static int decode_coap_option_nibble(int nibble, const uint8_t *payload, int length, int &offset) {

	if (nibble < 13)
		return nibble;

	if (nibble == 13) { // One extended byte holding (value - 13)
		if (length - offset < 1)
			return -1;
		return payload[offset++] + 13;
	}

	if (nibble == 14) { // Two extended bytes, network order, holding (value - 269)
		if (length - offset < 2)
			return -1;
		int value = (payload[offset] << 8) | payload[offset + 1];
		offset += 2;
		return value + 269;
	}

	return -1; // 15 is reserved (only valid inside the 0xFF payload marker)
}

// Walks the option list of a CoAP request.
// payload points to the first option byte and length is the number of bytes
// available from there to the end of the packet (may be <= 0, then no option
// is parsed). A Uri-Host option is checked against the banned domains (a hit
// bans the info and stops processing) and otherwise attached to the flow;
// Location-Path/Uri-Path segments are joined with '/' into uri_buffer_ and
// attached as the uri. On the first L7 packet of the flow the host name is
// matched against the domain manager, and a matched domain's uri set is
// checked against the uri. Parsing stops at the 0xFF payload marker or at
// the first malformed/truncated option, and never reads past length.
void CoAPProtocol::process_common_header(CoAPInfo *info, const uint8_t *payload, int length) {

	int offset = 0;
	int buffer_offset = 0;
	uint32_t type = 0; // Current option number: running sum of the deltas

	while (offset < length) {
		const uint8_t option_header = payload[offset];
		if (option_header == 0xFF) // End of options marker, the payload follows
			break;
		++offset;

		const int delta = decode_coap_option_nibble(option_header >> 4, payload, length, offset);
		if (delta < 0)
			break; // Reserved or truncated delta
		const int extension_length = decode_coap_option_nibble(option_header & 0x0F, payload, length, offset);
		if ((extension_length < 0)or(extension_length > length - offset))
			break; // Reserved length, or the value runs past the packet

		type += delta;
		if (type > 0xFFFF)
			break; // Option numbers are 16 bit values

		const char *dataptr = reinterpret_cast<const char*>(payload + offset);
		if (type == COAP_OPTION_URI_HOST) { // The hostname
			boost::string_ref hostname(dataptr, extension_length);

			if (ban_domain_mng_) {
				if (auto host_candidate = ban_domain_mng_->getDomainName(hostname); host_candidate) {
					++total_ban_hosts_;
					info->setIsBanned(true);
					return;
				}
			}
			++total_allow_hosts_;

			attach_host_to_flow(info, hostname);
		} else if ((type == COAP_OPTION_LOCATION_PATH)or(type == COAP_OPTION_URI_PATH)) {
			// Copy the parts of the uri on a temp buffer
			if ((buffer_offset + extension_length + 1) < MAX_URI_BUFFER) {
				std::memcpy(uri_buffer_ + buffer_offset, "/", 1);
				++buffer_offset;
				if (extension_length > 0)
					std::memcpy(uri_buffer_ + buffer_offset, dataptr, extension_length);
				buffer_offset += extension_length;
			}
		}

		offset += extension_length;
	}

	if (buffer_offset > 0) { // There is a uri
		boost::string_ref uri(uri_buffer_, buffer_offset);

		attach_uri(info, uri);
	}

	// Just verify the hostname on the first coap request
	if (current_flow_->total_packets_l7 == 1) {
		if ((domain_mng_)and(info->host_name)) {
			if (auto host_candidate = domain_mng_->getDomainName(info->host_name->name()); host_candidate) {
				++total_events_;
				info->matched_domain_name = host_candidate;
#if defined(BINDING)
				if (host_candidate->call.haveCallback())
					host_candidate->call.executeCallback(current_flow_);
#endif
			}
		}
	}

	// attach_uri() may fail to get a cache entry, so info->uri can be empty
	if ((info->matched_domain_name)and(buffer_offset > 0)and(info->uri)) {
		if (SharedPointer<HTTPUriSet> uset = info->matched_domain_name->getHTTPUriSet(); uset) {
			if (uset->lookupURI(info->uri->name())) {
				++total_events_;
#if defined(BINDING)
				if (uset->call.haveCallback())
					uset->call.executeCallback(current_flow_);
#endif
			}
		}
	}
}

// Attaches hostname to info unless a host name is already attached.
// An existing host_map_ entry is shared (and its hit count bumped);
// otherwise a new string is taken from host_cache_, filled and registered
// in host_map_. If the cache yields no entry, nothing is attached.
void CoAPProtocol::attach_host_to_flow(CoAPInfo *info, const boost::string_ref &hostname) {

	if (info->host_name)
		return;

	if (auto it = host_map_.find(hostname); it != host_map_.end()) {
		++it->second.hits;
		info->host_name = it->second.sc;
		return;
	}

	SharedPointer<StringCache> entry = host_cache_->acquire();
	if (!entry)
		return;

	entry->name(hostname.data(), hostname.length());
	info->host_name = entry;
	host_map_.insert(std::make_pair(entry->name(), entry));
}

// The URI should be updated on every request
// Sets uri on info; unlike the host, the uri is replaced on every request.
// An existing uri_map_ entry is shared; otherwise a new string is taken from
// uri_cache_, filled and registered in uri_map_. If the cache yields no
// entry, info keeps whatever uri it had before.
void CoAPProtocol::attach_uri(CoAPInfo *info, const boost::string_ref &uri) {

	auto it = uri_map_.find(uri);
	if (it != uri_map_.end()) {
		info->uri = it->second.sc;
		return;
	}

	SharedPointer<StringCache> entry = uri_cache_->acquire();
	if (!entry)
		return;

	entry->name(uri.data(), uri.length());
	info->uri = entry;
	uri_map_.insert(std::make_pair(entry->name(), entry));
}

// Replaces the domain name manager. The previous manager (if any) is
// unplugged from this protocol's name; the new one (if non-null) is plugged
// to it. Passing a null pointer leaves no manager attached.
void CoAPProtocol::setDomainNameManager(const SharedPointer<DomainNameManager>& dnm) {

	if (domain_mng_)
		domain_mng_->setPluggedToName("");

	domain_mng_ = dnm;

	if (domain_mng_)
		domain_mng_->setPluggedToName(name());
}

// Writes human readable statistics to out, more detail for higher levels:
// > 0 plugged domain managers, > 3 method counters and cache statistics,
// > 4 contents of the host/uri maps (up to limit), > 5 the flow forwarder.
void CoAPProtocol::statistics(std::basic_ostream<char>& out, int level, int32_t limit) const {

	showStatisticsHeader(out, level);

	if (level > 0) {
		if (ban_domain_mng_)
			out << "\t" << "Plugged banned domains from:" << ban_domain_mng_->name() << "\n";
		if (domain_mng_)
			out << "\t" << "Plugged domains from:" << domain_mng_->name() << "\n";
	}

	const bool detailed = level > 3;

	if (detailed) {
		out << "\t" << "Total gets:             " << std::setw(10) << total_coap_gets_ << "\n";
		out << "\t" << "Total posts:            " << std::setw(10) << total_coap_posts_ << "\n";
		out << "\t" << "Total puts:             " << std::setw(10) << total_coap_puts_ << "\n";
		out << "\t" << "Total delete:           " << std::setw(10) << total_coap_deletes_ << "\n";
		out << "\t" << "Total others:           " << std::setw(10) << total_coap_others_ << std::endl;
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (!detailed)
		return;

	info_cache_->statistics(out);
	host_cache_->statistics(out);
	uri_cache_->statistics(out);

	if (level > 4) {
		host_map_.show(out, "\t", limit);
		uri_map_.show(out, "\t", limit);
	}
}

// Writes the statistics header to out and, for level > 3, the per-method
// request counters under the "methods" key.
void CoAPProtocol::statistics(Json &out, int level) const {

	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	Json methods;

	methods["gets"] = total_coap_gets_;
	methods["posts"] = total_coap_posts_;
	methods["puts"] = total_coap_puts_;
	methods["deletes"] = total_coap_deletes_;
	methods["others"] = total_coap_others_;

	out["methods"] = methods;
}

// Pre-allocates value entries in each of the info, host and uri caches.
void CoAPProtocol::increaseAllocatedMemory(int value) {

	auto grow = [value] (auto &cache) { cache->create(value); };

	grow(info_cache_);
	grow(host_cache_);
	grow(uri_cache_);
}

// Destroys value entries in each of the info, host and uri caches.
void CoAPProtocol::decreaseAllocatedMemory(int value) {

	auto shrink = [value] (auto &cache) { cache->destroy(value); };

	shrink(info_cache_);
	shrink(host_cache_);
	shrink(uri_cache_);
}

// Snapshot of the packet/byte totals and per-method request counters.
CounterMap CoAPProtocol::getCounters() const {

	CounterMap counters;

	// Traffic volume
	counters.addKeyValue("packets", total_packets_);
	counters.addKeyValue("bytes", total_bytes_);

	// Request methods
	counters.addKeyValue("gets", total_coap_gets_);
	counters.addKeyValue("posts", total_coap_posts_);
	counters.addKeyValue("puts", total_coap_puts_);
	counters.addKeyValue("deletes", total_coap_deletes_);
	counters.addKeyValue("others", total_coap_others_);

	return counters;
}

#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
// Returns the contents of one of the string maps converted by addMapToHash()
// to the binding's hash type. name selects the map: "host" or "uri"
// (case-insensitive).
#if defined(PYTHON_BINDING)
boost::python::dict CoAPProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE CoAPProtocol::getCacheData(const std::string &name) const {
#endif
        if (boost::iequals(name, "host"))
		return addMapToHash(host_map_);
        else if (boost::iequals(name, "uri"))
		return addMapToHash(uri_map_);

	// Any other name falls back to a placeholder map.
	// NOTE(review): the resulting hash depends on what StringMap's {"", ""}
	// constructor builds — confirm it is meant to be empty.
	StringMap empty {"", ""};

        return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
// Returns the string cache selected by name ("host" or "uri",
// case-insensitive), or nullptr for any other name. Python binding only.
SharedPointer<Cache<StringCache>> CoAPProtocol::getCache(const std::string &name) {

        if (boost::iequals(name, "host"))
                return host_cache_;
        else if (boost::iequals(name, "uri"))
                return uri_cache_;

        return nullptr;
}
#endif

#endif

// Dumps one of the string maps into out, up to limit entries. map_name is
// "hosts" or "uris" (case-insensitive); any other name writes nothing.
void CoAPProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	if (boost::iequals(map_name, "hosts"))
		host_map_.show(out, limit);
	else if (boost::iequals(map_name, "uris"))
		uri_map_.show(out, limit);
}

// Resets the generic protocol counters via reset() and then zeroes the
// CoAP specific event, host and method counters.
void CoAPProtocol::resetCounters() {

	reset();

	total_events_ = 0;

	// Host filtering outcomes
	total_allow_hosts_ = total_ban_hosts_ = 0;

	// Request methods
	total_coap_gets_ = total_coap_posts_ = total_coap_puts_ = 0;
	total_coap_deletes_ = total_coap_others_ = 0;
}

} // namespace aiengine
